{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "PyTorch/XLA ResNet18/CIFAR10 Training",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YX1hxqUQn47M",
        "colab_type": "text"
      },
      "source": [
        "## PyTorch/XLA ResNet18/CIFAR10 (GPU or TPU)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pLQPoJ6Fn8wF",
        "colab_type": "text"
      },
      "source": [
        "### [RUNME] Install Colab compatible PyTorch/XLA wheels and dependencies\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "O53lrJMDn9Rd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IednejwkIW-K",
        "colab_type": "text"
      },
      "source": [
        "Only run the below commented cell if you would like a nightly release"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w-bBdzgeISaP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# VERSION = \"nightly\"  #@param [\"nightly\", \"20200516\"]  # or YYYYMMDD format\n",
        "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "# !python pytorch-xla-env-setup.py --version $VERSION\n",
        "# import os \n",
        "# os.environ['LD_LIBRARY_PATH']='/usr/local/lib'\n",
        "# !echo $LD_LIBRARY_PATH\n",
        "\n",
        "# !sudo ln -s /usr/local/lib/libmkl_intel_lp64.so /usr/local/lib/libmkl_intel_lp64.so.1\n",
        "# !sudo ln -s /usr/local/lib/libmkl_intel_thread.so /usr/local/lib/libmkl_intel_thread.so.1\n",
        "# !sudo ln -s /usr/local/lib/libmkl_core.so /usr/local/lib/libmkl_core.so.1\n",
        "\n",
        "# !ldconfig\n",
        "# !ldd /usr/local/lib/python3.7/dist-packages/torch/lib/libtorch.so"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xiFzLg5gy7l6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# PyTorch/XLA GPU Setup (only if GPU runtime)\n",
        "import os\n",
        "if os.environ.get('COLAB_GPU', '0') == '1':\n",
        "  os.environ['GPU_NUM_DEVICES'] = '1'\n",
        "  os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda/'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rroH9yiAn-XE",
        "colab_type": "text"
      },
      "source": [
        "### Define Parameters\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cMojPWZUqr2s",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Result Visualization Helper\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "M, N = 4, 6\n",
        "RESULT_IMG_PATH = '/tmp/test_result.jpg'\n",
        "CIFAR10_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
        "                 'dog', 'frog', 'horse', 'ship', 'truck']\n",
        "\n",
        "def plot_results(images, labels, preds):\n",
        "  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]\n",
        "  inv_norm = transforms.Normalize(\n",
        "      mean=(-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010),\n",
        "      std=(1/0.2023, 1/0.1994, 1/0.2010))\n",
        "\n",
        "  num_images = images.shape[0]\n",
        "  fig, axes = plt.subplots(M, N, figsize=(16, 9))\n",
        "  fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')\n",
        "\n",
        "  for i, ax in enumerate(fig.axes):\n",
        "    ax.axis('off')\n",
        "    if i >= num_images:\n",
        "      continue\n",
        "    img, label, prediction = images[i], labels[i], preds[i]\n",
        "    img = inv_norm(img)\n",
        "    img = img.permute(1, 2, 0) # (C, M, N) -> (M, N, C)\n",
        "    label, prediction = label.item(), prediction.item()\n",
        "    if label == prediction:\n",
        "      ax.set_title(u'\\u2713', color='blue', fontsize=22)\n",
        "    else:\n",
        "      ax.set_title(\n",
        "          'X {}/{}'.format(CIFAR10_LABELS[label],\n",
        "                          CIFAR10_LABELS[prediction]), color='red')\n",
        "    ax.imshow(img)\n",
        "  plt.savefig(RESULT_IMG_PATH, transparent=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iMdPRFXIn_jH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Define Parameters\n",
        "FLAGS = {}\n",
        "FLAGS['data_dir'] = \"/tmp/cifar\"\n",
        "FLAGS['batch_size'] = 128\n",
        "FLAGS['num_workers'] = 4\n",
        "FLAGS['learning_rate'] = 0.02\n",
        "FLAGS['momentum'] = 0.9\n",
        "FLAGS['num_epochs'] = 20\n",
        "FLAGS['num_cores'] = 8 if os.environ.get('TPU_NAME', None) else 1\n",
        "FLAGS['log_steps'] = 20\n",
        "FLAGS['metrics_debug'] = False"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Micd3xZvoA-c",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import os\n",
        "import time\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "import torch_xla.utils.utils as xu\n",
        "import torchvision\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "class BasicBlock(nn.Module):\n",
        "  expansion = 1\n",
        "\n",
        "  def __init__(self, in_planes, planes, stride=1):\n",
        "    super(BasicBlock, self).__init__()\n",
        "    self.conv1 = nn.Conv2d(\n",
        "        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
        "    self.bn1 = nn.BatchNorm2d(planes)\n",
        "    self.conv2 = nn.Conv2d(\n",
        "        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "    self.bn2 = nn.BatchNorm2d(planes)\n",
        "\n",
        "    self.shortcut = nn.Sequential()\n",
        "    if stride != 1 or in_planes != self.expansion * planes:\n",
        "      self.shortcut = nn.Sequential(\n",
        "          nn.Conv2d(\n",
        "              in_planes,\n",
        "              self.expansion * planes,\n",
        "              kernel_size=1,\n",
        "              stride=stride,\n",
        "              bias=False), nn.BatchNorm2d(self.expansion * planes))\n",
        "\n",
        "  def forward(self, x):\n",
        "    out = F.relu(self.bn1(self.conv1(x)))\n",
        "    out = self.bn2(self.conv2(out))\n",
        "    out += self.shortcut(x)\n",
        "    out = F.relu(out)\n",
        "    return out\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "\n",
        "  def __init__(self, block, num_blocks, num_classes=10):\n",
        "    super(ResNet, self).__init__()\n",
        "    self.in_planes = 64\n",
        "\n",
        "    self.conv1 = nn.Conv2d(\n",
        "        3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "    self.bn1 = nn.BatchNorm2d(64)\n",
        "    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
        "    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
        "    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
        "    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
        "    self.linear = nn.Linear(512 * block.expansion, num_classes)\n",
        "\n",
        "  def _make_layer(self, block, planes, num_blocks, stride):\n",
        "    strides = [stride] + [1] * (num_blocks - 1)\n",
        "    layers = []\n",
        "    for stride in strides:\n",
        "      layers.append(block(self.in_planes, planes, stride))\n",
        "      self.in_planes = planes * block.expansion\n",
        "    return nn.Sequential(*layers)\n",
        "\n",
        "  def forward(self, x):\n",
        "    out = F.relu(self.bn1(self.conv1(x)))\n",
        "    out = self.layer1(out)\n",
        "    out = self.layer2(out)\n",
        "    out = self.layer3(out)\n",
        "    out = self.layer4(out)\n",
        "    out = F.avg_pool2d(out, 4)\n",
        "    out = torch.flatten(out, 1)\n",
        "    out = self.linear(out)\n",
        "    return F.log_softmax(out, dim=1)\n",
        "\n",
        "\n",
        "def ResNet18():\n",
        "  return ResNet(BasicBlock, [2, 2, 2, 2])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8vMl96KLoCq8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "SERIAL_EXEC = xmp.MpSerialExecutor()\n",
        "# Only instantiate model weights once in memory.\n",
        "WRAPPED_MODEL = xmp.MpModelWrapper(ResNet18())\n",
        "\n",
        "def train_resnet18():\n",
        "  torch.manual_seed(1)\n",
        "\n",
        "  def get_dataset():\n",
        "    norm = transforms.Normalize(\n",
        "        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))\n",
        "    transform_train = transforms.Compose([\n",
        "        transforms.RandomCrop(32, padding=4),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.ToTensor(),\n",
        "        norm,\n",
        "    ])\n",
        "    transform_test = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        norm,\n",
        "    ])\n",
        "    train_dataset = datasets.CIFAR10(\n",
        "        root=FLAGS['data_dir'],\n",
        "        train=True,\n",
        "        download=True,\n",
        "        transform=transform_train)\n",
        "    test_dataset = datasets.CIFAR10(\n",
        "        root=FLAGS['data_dir'],\n",
        "        train=False,\n",
        "        download=True,\n",
        "        transform=transform_test)\n",
        "    \n",
        "    return train_dataset, test_dataset\n",
        "  \n",
        "  # Using the serial executor avoids multiple processes\n",
        "  # to download the same data.\n",
        "  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)\n",
        "\n",
        "  train_sampler = torch.utils.data.distributed.DistributedSampler(\n",
        "      train_dataset,\n",
        "      num_replicas=xm.xrt_world_size(),\n",
        "      rank=xm.get_ordinal(),\n",
        "      shuffle=True)\n",
        "  train_loader = torch.utils.data.DataLoader(\n",
        "      train_dataset,\n",
        "      batch_size=FLAGS['batch_size'],\n",
        "      sampler=train_sampler,\n",
        "      num_workers=FLAGS['num_workers'],\n",
        "      drop_last=True)\n",
        "  test_loader = torch.utils.data.DataLoader(\n",
        "      test_dataset,\n",
        "      batch_size=FLAGS['batch_size'],\n",
        "      shuffle=False,\n",
        "      num_workers=FLAGS['num_workers'],\n",
        "      drop_last=True)\n",
        "\n",
        "  # Scale learning rate to num cores\n",
        "  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()\n",
        "\n",
        "  # Get loss function, optimizer, and model\n",
        "  device = xm.xla_device()\n",
        "  model = WRAPPED_MODEL.to(device)\n",
        "  optimizer = optim.SGD(model.parameters(), lr=learning_rate,\n",
        "                        momentum=FLAGS['momentum'], weight_decay=5e-4)\n",
        "  loss_fn = nn.NLLLoss()\n",
        "\n",
        "  def train_loop_fn(loader):\n",
        "    tracker = xm.RateTracker()\n",
        "    model.train()\n",
        "    for x, (data, target) in enumerate(loader):\n",
        "      optimizer.zero_grad()\n",
        "      output = model(data)\n",
        "      loss = loss_fn(output, target)\n",
        "      loss.backward()\n",
        "      xm.optimizer_step(optimizer)\n",
        "      tracker.add(FLAGS['batch_size'])\n",
        "      if x % FLAGS['log_steps'] == 0:\n",
        "        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(\n",
        "            xm.get_ordinal(), x, loss.item(), tracker.rate(),\n",
        "            tracker.global_rate(), time.asctime()), flush=True)\n",
        "\n",
        "  def test_loop_fn(loader):\n",
        "    total_samples = 0\n",
        "    correct = 0\n",
        "    model.eval()\n",
        "    data, pred, target = None, None, None\n",
        "    for data, target in loader:\n",
        "      output = model(data)\n",
        "      pred = output.max(1, keepdim=True)[1]\n",
        "      correct += pred.eq(target.view_as(pred)).sum().item()\n",
        "      total_samples += data.size()[0]\n",
        "\n",
        "    accuracy = 100.0 * correct / total_samples\n",
        "    print('[xla:{}] Accuracy={:.2f}%'.format(\n",
        "        xm.get_ordinal(), accuracy), flush=True)\n",
        "    return accuracy, data, pred, target\n",
        "\n",
        "  # Train and eval loops\n",
        "  accuracy = 0.0\n",
        "  data, pred, target = None, None, None\n",
        "  for epoch in range(1, FLAGS['num_epochs'] + 1):\n",
        "    para_loader = pl.ParallelLoader(train_loader, [device])\n",
        "    train_loop_fn(para_loader.per_device_loader(device))\n",
        "    xm.master_print(\"Finished training epoch {}\".format(epoch))\n",
        "\n",
        "    para_loader = pl.ParallelLoader(test_loader, [device])\n",
        "    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))\n",
        "    if FLAGS['metrics_debug']:\n",
        "      xm.master_print(met.metrics_report(), flush=True)\n",
        "\n",
        "  return accuracy, data, pred, target\n",
        "  "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_2nL4HmloEyl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start training processes\n",
        "def _mp_fn(rank, flags):\n",
        "  global FLAGS\n",
        "  FLAGS = flags\n",
        "  torch.set_default_tensor_type('torch.FloatTensor')\n",
        "  accuracy, data, pred, target = train_resnet18()\n",
        "  if rank == 0:\n",
        "    # Retrieve tensors that are on TPU core 0 and plot.\n",
        "    plot_results(data.cpu(), pred.cpu(), target.cpu())\n",
        "\n",
        "xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],\n",
        "          start_method='fork')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_wt7wEVJoFmf",
        "colab_type": "text"
      },
      "source": [
        "## Visualize Predictions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mSdVUMPjoGhy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from google.colab.patches import cv2_imshow\n",
        "import cv2\n",
        "img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)\n",
        "cv2_imshow(img)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
